Machine learning to segment neutron images
Anders Kaestner, Beamline scientist - Neutron Imaging
Laboratory for Neutron Scattering and Imaging
Paul Scherrer Institut
If you want to run the notebook on your own computer, you'll need to perform the following step:
git clone https://github.com/ImagingLectures/MLSegmentation4NI.git
conda env create -f environment.yml -n MLSeg4NI
conda env activate MLSeg4NI
This lecture needs some modules to run. We import all of them here.
import matplotlib.pyplot as plt
import seaborn as sn
import numpy as np
import pandas as pd
import skimage.filters as flt
import skimage.io as io
import matplotlib as mpl
from sklearn.cluster import KMeans
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import confusion_matrix
from sklearn.datasets import make_blobs
from matplotlib.colors import ListedColormap
from lecturesupport import plotsupport as ps
import scipy.stats as stats
import astropy.io.fits as fits
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate
# Render figures inline in the notebook.
%matplotlib inline
from IPython.display import set_matplotlib_formats
# Prefer vector (SVG) output with PNG fallback for crisp figures.
set_matplotlib_formats('svg', 'png')
#plt.style.use('seaborn')
# High figure resolution for the lecture slides.
mpl.rcParams['figure.dpi'] = 300
Using TensorFlow backend.
# Reload the lecture-support module so local edits take effect without
# restarting the kernel; the trailing ';' suppresses the echoed module repr.
import importlib
importlib.reload(ps);
Introduction to neutron imaging
Introduction to segmentation
Problematic segmentation tasks
A very abstract definition:
In most cases this is a two- or three-dimensional position (x,y,z coordinates) and a numeric value (intensity)
Images are great for qualitative analyses since our brains can quickly interpret them without large programming investments.
| Transmission through sample | X-ray attenuation | Neutron attenuation |

Start out with a simple image of a cross with added noise
$$ I(x,y) = f(x,y) $$

fig,ax = plt.subplots(1,2,figsize=(12,6))
# Build a synthetic "cross" test image: a cosine-based cross pattern with
# added uniform noise, then show the image next to its gray-value histogram.
nx = 5
ny = 5
x_axis = np.arange(-nx, nx + 1) / nx * 2 * np.pi
y_axis = np.arange(-ny, ny + 1) / ny * 2 * np.pi
xx, yy = np.meshgrid(x_axis, y_axis)
signal = 1.5 * np.abs(np.cos(xx * yy)) / (np.abs(xx * yy) + (3 * np.pi / nx))
noise = np.random.uniform(-0.25, 0.25, size=xx.shape)
cross_im = signal + noise
im = ax[0].imshow(cross_im, cmap='hot')
ax[0].set_title("Image")
ax[1].hist(cross_im.ravel(), bins=10)
ax[1].set_xlabel('Gray value')
ax[1].set_ylabel('Counts')
ax[1].set_title("Histogram")
Applying the threshold is a deceptively simple operation
$$ I(x,y) = \begin{cases} 1, & f(x,y)\geq0.40 \\ 0, & f(x,y)<0.40 \end{cases}$$

threshold = 0.4; thresh_img = cross_im > threshold
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
# Left panel: the image with thresholded pixels overlaid as green squares.
extent = [xx.min(), xx.max(), yy.min(), yy.max()]
ax[0].imshow(cross_im, cmap='hot', extent=extent)
ax[0].set_title("Image")
hit_rows, hit_cols = np.where(thresh_img)
ax[0].plot(xx[hit_rows, hit_cols] * 0.9, yy[hit_rows, hit_cols] * 0.9,
           'ks', markerfacecolor='green', alpha=0.5,
           label='Threshold', markersize=22)
ax[0].legend(fontsize=12)
# Right panel: histogram with the threshold position marked.
ax[1].hist(cross_im.ravel(), bins=10)
ax[1].axvline(x=threshold, color='r', label='Threshold')
ax[1].legend(fontsize=12)
ax[1].set_xlabel('Gray value')
ax[1].set_ylabel('Counts')
ax[1].set_title("Histogram")
The noise in neutron imaging mainly originates from the amount of captured neutrons.
This noise is Poisson distributed and the signal to noise ratio is
$$SNR=\frac{E[x]}{s[x]}\sim\frac{N}{\sqrt{N}}=\sqrt{N}$$
Woodland Encounter Bev Doolittle
Different types of limited data:
Data augmentation is a method to modify your existing data to obtain variations of it.
Augmentation will be used to increase the training data in the root segmentation example at the end of this lecture.
Both augmented and simulated data should be combined with real data.
Transfer learning is a technique that uses a pre-trained network to
# Generate a toy 2D point cloud and show how k-means partitions it
# when asked for 2, 3, and 4 clusters.
test_pts = pd.DataFrame(make_blobs(n_samples=200, random_state=2018)[0],
                        columns=['x', 'y'])
plt.plot(test_pts.x, test_pts.y, 'r.')
fig, ax = plt.subplots(1, 3, figsize=(15, 4.5))
for panel, n_clusters in enumerate(range(2, 5)):
    km = KMeans(n_clusters=n_clusters, random_state=2018)
    n_grp = km.fit_predict(test_pts)
    ax[panel].scatter(test_pts.x, test_pts.y, c=n_grp)
    ax[panel].set_title('{0} groups'.format(n_clusters))
# Load the time-of-flight (ToF) imaging stack; axis 2 holds the time bins.
tof = np.load('../data/tofdata.npy')
wtof = tof.mean(axis=2)  # average over all time bins ("white beam" image)
plt.imshow(wtof, cmap='gray')
plt.title('Average intensity all time bins')
# Mark four pixel positions and plot their ToF spectra.
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
ax[0].imshow(wtof, cmap='gray')
ax[0].set_title('Average intensity all time bins')
ax[0].plot(57, 3, 'ro')
ax[0].plot(15, 30, 'bo')
ax[0].plot(79, 90, 'go')
ax[0].plot(100, 120, 'co')
ax[1].plot(tof[30, 15, :], 'b', label='Sample')
ax[1].plot(tof[3, 57, :], 'r', label='Background')
ax[1].plot(tof[90, 79, :], 'g', label='Spacer')
ax[1].legend()
ax[1].plot(tof[120, 100, :], 'c', label='Sample 2')
# Flatten the stack to one spectrum per pixel for clustering.
tofr = tof.reshape([tof.shape[0] * tof.shape[1], tof.shape[2]])
print("Input ToF dimensions", tof.shape)
print("Reshaped ToF data", tofr.shape)
Input ToF dimensions (128, 128, 661) Reshaped ToF data (16384, 661)
# Cluster the per-pixel spectra into four groups.
km = KMeans(n_clusters=4, random_state=2018)
labels = km.fit_predict(tofr)
c = labels.reshape(tof.shape[:2])      # label image
kc = km.cluster_centers_.transpose()   # one centroid spectrum per column
Results from the first try
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
axes = axes.ravel()
axes[0].imshow(wtof, cmap='viridis')
axes[0].set_title('Average image')
p = axes[1].plot(kc)
axes[1].set_title('Cluster centroid spectra')
axes[1].set_aspect(tof.shape[2], adjustable='box')
# Reuse the spectrum line colors for the label image so regions and
# spectra can be matched visually.
cmap = ps.buildCMap(p)
im = axes[2].imshow(c, cmap=cmap)
plt.colorbar(im)
axes[2].set_title('Cluster map')
plt.tight_layout()
# Repeat the clustering with ten clusters to capture more detail.
km = KMeans(n_clusters=10, random_state=2018)
labels10 = km.fit_predict(tofr)
c = labels10.reshape(tof.shape[:2])    # label image
kc = km.cluster_centers_.transpose()   # centroid spectra
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
axes = axes.ravel()
axes[0].imshow(wtof, cmap='viridis')
axes[0].set_title('Average image')
p = axes[1].plot(kc)
axes[1].set_title('Cluster centroid spectra')
axes[1].set_aspect(tof.shape[2], adjustable='box')
# Match label-image colors to the spectrum line colors.
cmap = ps.buildCMap(p)
im = axes[2].imshow(c, cmap=cmap)
plt.colorbar(im)
axes[2].set_title('Cluster map')
plt.tight_layout()
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
axes = axes.ravel()
# Correlation between centroid spectra shows which clusters are redundant.
axes[0].matshow(np.corrcoef(kc.transpose()))
axes[1].plot(kc)
axes[1].set_title('Cluster centroid spectra')
axes[1].set_aspect(tof.shape[2], adjustable='box')
# Release the large arrays before the next example.
del km, c, kc, tofr, tof
# Supervised example: a blob data set where the group labels are known.
blob_data, blob_labels = make_blobs(n_samples=100, random_state=2018)
test_pts = pd.DataFrame(blob_data, columns=['x', 'y'])
test_pts['group_id'] = blob_labels
plt.scatter(test_pts.x, test_pts.y, c=test_pts.group_id, cmap='viridis')
# Load the radiograph (FITS) and the manual spot annotation; spots are
# marked where the green channel of the annotation image is zero.
orig = fits.getdata('../data/spots/mixture12_00001.fits')
annotated = io.imread('../data/spots/mixture12_00001.png')
mask = (annotated[:, :, 1] == 0)
# Region of interest used by the magnified views below.
r = 600
c = 600
w = 256
ps.magnifyRegion(orig, [r, c, r + w, c + w], [15, 7],
                 vmin=400, vmax=4000, title='Neutron radiography')
Parameters
def spotCleaner(img, threshold=0.95, selem=None):
    """Median-based outlier (spot) removal.

    Pixels deviating from the local median by more than ``threshold`` are
    replaced by that median; all other pixels are left untouched.

    Parameters
    ----------
    img : ndarray
        Input image (converted to float32 internally).
    threshold : float
        Absolute deviation above which a pixel is treated as a spot.
    selem : ndarray, optional
        Median-filter neighborhood; defaults to a 3x3 box. The original
        used a mutable default argument (``np.ones([3,3])``), which is
        shared across calls; it is now created per call.

    Returns
    -------
    tuple of ndarray
        ``(cleaned, timg)``: the cleaned image and the boolean spot mask.
    """
    if selem is None:
        selem = np.ones([3, 3])
    fimg = img.astype('float32')
    mimg = flt.median(fimg, selem=selem)
    timg = threshold < np.abs(fimg - mimg)
    # Median value where a spot was detected, original value elsewhere
    # (equivalent to the original blend mimg*timg + fimg*(1-timg)).
    cleaned = np.where(timg, mimg, fimg)
    return (cleaned, timg)
# Run the baseline cleaner and magnify the ROI in both the cleaned image
# and the detection mask.
baseclean, timg = spotCleaner(orig, threshold=1000)
roi = [r, c, r + w, c + w]
ps.magnifyRegion(baseclean, roi, [12, 3], vmin=400, vmax=4000,
                 title='Cleaned image')
ps.magnifyRegion(timg, roi, [12, 3], vmin=0, vmax=1,
                 title='Detection image')
# Feature space for spot detection: gray value f versus |f - median(f)|.
selem = np.ones([3, 3])
forig = orig.astype('float32')
mimg = flt.median(forig, selem=selem)
d = np.abs(forig - mimg)
fig, ax = plt.subplots(1, 1, figsize=(8, 5))
h, x, y, u = ax.hist2d(forig[:1024, :].ravel(), d[:1024, :].ravel(), bins=100)
# Overlay the log-scaled counts; rows are flipped so the origin matches imshow.
ax.imshow(np.log(h[::-1] + 1), vmin=0, vmax=3,
          extent=[x.min(), x.max(), y.min(), y.max()])
ax.set_xlabel('Input image - $f$')
ax.set_ylabel('$|f-med_{3x3}(f)|$')
ax.set_title('Log bivariate histogram')
Training data
# Training features come from the left part of the image (columns 0-999).
trainorig = forig[:, :1000].ravel()
traind = d[:, :1000].ravel()
trainmask = mask[:, :1000].ravel()
train_pts = pd.DataFrame(dict(orig=trainorig, d=traind, mask=trainmask))
Test data
# Test features come from the right part of the image (columns 1000+).
testorig = forig[:, 1000:].ravel()
testd = d[:, 1000:].ravel()
testmask = mask[:, 1000:].ravel()
test_pts = pd.DataFrame(dict(orig=testorig, d=testd, mask=testmask))
# 1-nearest-neighbour classifier on the (gray value, deviation) features.
k_class = KNeighborsClassifier(n_neighbors=1)
k_class.fit(train_pts[['orig', 'd']], train_pts['mask'])
KNeighborsClassifier(n_neighbors=1)
Inspect decision space
# Sample the feature plane on a 100x100 grid and color it by the
# classifier's prediction to visualize the decision boundary.
orig_axis = np.linspace(test_pts.orig.min(), test_pts.orig.max(), 100)
d_axis = np.linspace(test_pts.d.min(), test_pts.d.max(), 100)
xx, yy = np.meshgrid(orig_axis, d_axis, indexing='ij')
grid_pts = pd.DataFrame(dict(x=xx.ravel(), y=yy.ravel()))
grid_pts['predicted_id'] = k_class.predict(grid_pts[['x', 'y']])
plt.scatter(grid_pts.x, grid_pts.y, c=grid_pts.predicted_id, cmap='gray')
plt.title('Testing Points')
plt.axis('square')
# Predict on the test half and compare with the annotation.
pred = k_class.predict(test_pts[['orig', 'd']])
# BUG FIX: the test set was built from the right-hand COLUMNS ([:, 1000:]),
# but the original reshaped/displayed with ROW slicing ([1000:, :]), which
# only happens to work (garbled) when the image is square. Use the same
# slicing as the test data throughout.
pimg = pred.reshape(d[:, 1000:].shape)
fig, ax = plt.subplots(1, 3, figsize=(15, 6))
ax[0].imshow(forig[:, 1000:], vmin=0, vmax=4000)
ax[0].set_title('Original image')
ax[1].imshow(pimg)
ax[1].set_title('Predicted spot')
ax[2].imshow(mask[:, 1000:])
ax[2].set_title('Annotated spots')
# Normalized confusion matrices for the baseline and the k-NN detector,
# both evaluated on the test half of the image.
gt_flat = mask[:, 1000:].ravel()
cmbase = confusion_matrix(gt_flat, timg[:, 1000:].ravel(), normalize='all')
cmknn = confusion_matrix(gt_flat, pimg.ravel(), normalize='all')
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
sn.heatmap(cmbase, annot=True, ax=ax[0])
ax[0].set_title('Confusion matrix baseline')
sn.heatmap(cmknn, annot=True, ax=ax[1])
ax[1].set_title('Confusion matrix k-NN')
Note There are other spot detection methods that perform better than the baseline.
# Free the classifier and confusion matrices before the deep-learning part.
del k_class, cmbase, cmknn
# Keras optimizer/loss/metric modules used to configure the U-Net training.
import keras.optimizers as opt
import keras.losses as loss
import keras.metrics as metrics
We have two choices:
We will use the spotty image as training data for this example
Any analysis system must be verified to demonstrate its performance and to further optimize it.
For this we need to split our data into three categories:
| Training | Validation | Test |
|---|---|---|
| 70% | 15% | 15% |
def buildSpotUNet(base_depth=48):
    """Build a small two-level U-Net for binary spot segmentation.

    Parameters
    ----------
    base_depth : int
        Number of convolution filters at the top level; the deeper levels
        use 2x and 4x this many.

    Returns
    -------
    keras.models.Model
        Fully convolutional model accepting a single-channel image of any
        size and producing a per-pixel spot probability map.
    """
    in_img = Input((None, None, 1), name='Image_Input')
    # Encoder: two conv blocks, each followed by 2x2 max pooling.
    lay_1 = Conv2D(base_depth, kernel_size=(3, 3), padding='same', activation='relu')(in_img)
    lay_2 = Conv2D(base_depth, kernel_size=(3, 3), padding='same', activation='relu')(lay_1)
    lay_3 = MaxPooling2D(pool_size=(2, 2))(lay_2)
    lay_4 = Conv2D(base_depth*2, kernel_size=(3, 3), padding='same', activation='relu')(lay_3)
    lay_5 = Conv2D(base_depth*2, kernel_size=(3, 3), padding='same', activation='relu')(lay_4)
    lay_6 = MaxPooling2D(pool_size=(2, 2))(lay_5)
    # Bottleneck.
    lay_7 = Conv2D(base_depth*4, kernel_size=(3, 3), padding='same', activation='relu')(lay_6)
    lay_8 = Conv2D(base_depth*4, kernel_size=(3, 3), padding='same', activation='relu')(lay_7)
    # Decoder: upsample and concatenate with the matching encoder output.
    lay_9 = UpSampling2D((2, 2))(lay_8)
    lay_10 = concatenate([lay_5, lay_9])
    lay_11 = Conv2D(base_depth*2, kernel_size=(3, 3), padding='same', activation='relu')(lay_10)
    lay_12 = Conv2D(base_depth*2, kernel_size=(3, 3), padding='same', activation='relu')(lay_11)
    lay_13 = UpSampling2D((2, 2))(lay_12)
    lay_14 = concatenate([lay_2, lay_13])
    lay_15 = Conv2D(base_depth, kernel_size=(3, 3), padding='same', activation='relu')(lay_14)
    lay_16 = Conv2D(base_depth, kernel_size=(3, 3), padding='same', activation='relu')(lay_15)
    # BUG FIX: the final 1x1 convolution used a 'relu' activation, whose
    # unbounded output does not match the BinaryCrossentropy loss used to
    # train this model; 'sigmoid' maps each pixel to a probability in [0, 1].
    lay_17 = Conv2D(1, kernel_size=(1, 1), padding='same',
                    activation='sigmoid')(lay_16)
    t_unet = Model(inputs=[in_img], outputs=[lay_17], name='SpotUNET')
    return t_unet
Model summary
# Instantiate a small U-Net (24 base filters) and print its architecture.
t_unet = buildSpotUNet(base_depth=24)
t_unet.summary()
WARNING:tensorflow:From /home/travis/miniconda/envs/book/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /home/travis/miniconda/envs/book/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:4070: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.
Model: "SpotUNET"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
Image_Input (InputLayer) (None, None, None, 1 0
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, None, None, 2 240 Image_Input[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, None, None, 2 5208 conv2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, None, None, 2 0 conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, None, None, 4 10416 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, None, None, 4 20784 conv2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, None, None, 4 0 conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, None, None, 9 41568 max_pooling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, None, None, 9 83040 conv2d_5[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, None, None, 9 0 conv2d_6[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, None, None, 1 0 conv2d_4[0][0]
up_sampling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, None, None, 4 62256 concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, None, None, 4 20784 conv2d_7[0][0]
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D) (None, None, None, 4 0 conv2d_8[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, None, None, 7 0 conv2d_2[0][0]
up_sampling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, None, None, 2 15576 concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, None, None, 2 5208 conv2d_9[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, None, None, 1 25 conv2d_10[0][0]
==================================================================================================
Total params: 265,105
Trainable params: 265,105
Non-trainable params: 0
__________________________________________________________________________________________________
# Crop training and validation regions plus a separate test window.
train_img = forig[128:256, 500:1300]
valid_img = forig[500:1000, 300:1500]
train_mask = mask[128:256, 500:1300]
valid_mask = mask[500:1000, 300:1500]
wpos = [600, 600]
ww = 512
row0, col0 = wpos
forigc = forig[row0:row0 + ww, col0:col0 + ww]
maskc = mask[row0:row0 + ww, col0:col0 + ww]
# train_img, valid_img = forig[128:256, 300:1500], forig[500:, 300:1500]
# train_mask, valid_mask = mask[128:256, 300:1500], mask[500:, 300:1500]
fig, ax = plt.subplots(1, 4, figsize=(15, 6), dpi=300)
ax = ax.ravel()
ax[0].imshow(train_img, cmap='bone', vmin=0, vmax=4000)
ax[0].set_title('Train Image')
ax[1].imshow(train_mask, cmap='bone')
ax[1].set_title('Train Mask')
ax[2].imshow(valid_img, cmap='bone', vmin=0, vmax=4000)
ax[2].set_title('Validation Image')
ax[3].imshow(valid_mask, cmap='bone')
ax[3].set_title('Validation Mask')
def prep_img(x, n=1):
    """Stack *x* into an (n, H, W, 1) batch and normalize it with the
    training image's mean and standard deviation."""
    batch = prep_mask(x, n=n)
    return (batch - train_img.mean()) / train_img.std()
def prep_mask(x, n=1):
    """Append a channel axis to *x* and replicate it *n* times along a
    new leading batch axis, giving shape (n, H, W, 1)."""
    with_channel = np.expand_dims(x, -1)
    return np.repeat(with_channel[np.newaxis, ...], n, axis=0)
# Predict on the test window with the still-untrained network (baseline output).
unet_pred = t_unet.predict(prep_img(forigc))[0, :, :, 0]
WARNING:tensorflow:From /home/travis/miniconda/envs/book/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.
# Compare training data, untrained prediction, and ground truth.
fig, m_axs = plt.subplots(2, 3, figsize=(15, 6), dpi=150)
for c_ax in m_axs.ravel():
    c_ax.axis('off')
((ax1, _, ax2), (ax3, ax4, ax5)) = m_axs
ax1.imshow(train_img, cmap='bone', vmin=0, vmax=4000)
ax1.set_title('Train Image')
ax2.imshow(train_mask, cmap='viridis')
ax2.set_title('Train Mask')
ax3.imshow(forigc, cmap='bone', vmin=0, vmax=4000)
ax3.set_title('Test Image')
ax4.imshow(unet_pred, cmap='viridis', vmin=0, vmax=1)
ax4.set_title('Predicted Segmentation')
ax5.imshow(maskc, cmap='viridis')
ax5.set_title('Ground Truth')
Another popular metric is the Dice score $$DSC=\frac{2|X \cap Y|}{|X|+|Y|}=\frac{2\,TP}{2TP+FP+FN}$$
# Metrics tracked during training, one instance per (name, class) pair.
mlist = [cls(name=tag) for tag, cls in [
    ('tp', metrics.TruePositives), ('fp', metrics.FalsePositives),
    ('tn', metrics.TrueNegatives), ('fn', metrics.FalseNegatives),
    ('accuracy', metrics.BinaryAccuracy), ('precision', metrics.Precision),
    ('recall', metrics.Recall), ('auc', metrics.AUC),
    ('mae', metrics.MeanAbsoluteError)]]
t_unet.compile(
    loss=loss.BinaryCrossentropy(),  # binary cross-entropy for the 2-class mask
    optimizer=opt.Adam(lr=1e-3),     # ADAM with learning rate 1e-3
    metrics=mlist,                   # track everything listed above
)
WARNING:tensorflow:From /home/travis/miniconda/envs/book/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:3172: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version. Instructions for updating: Use tf.where in 2.0, which has the same broadcast rule as np.where
This is a very bad way to train a model;
The goal is to be aware of these techniques and have a feeling for how they can work for complex problems
# Train for 20 epochs. The single training crop is replicated three times
# (n=3) to form a tiny batch; the validation crop is used as-is.
loss_history = t_unet.fit(
    prep_img(train_img, n=3),
    prep_mask(train_mask, n=3),
    validation_data=(prep_img(valid_img), prep_mask(valid_mask)),
    epochs=20,
    verbose=1,
)
Train on 3 samples, validate on 1 samples Epoch 1/20 3/3 [==============================] - 10s 3s/step - loss: 0.0936 - tp: 0.0000e+00 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2544.0000 - accuracy: 0.9917 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.5055 - mae: 0.0159 - val_loss: 0.0716 - val_tp: 3.0000 - val_fp: 6.0000 - val_tn: 593510.0000 - val_fn: 6481.0000 - val_accuracy: 0.9892 - val_precision: 0.3333 - val_recall: 4.6268e-04 - val_auc: 0.7423 - val_mae: 0.0153 Epoch 2/20 3/3 [==============================] - 7s 2s/step - loss: 0.0545 - tp: 0.0000e+00 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2544.0000 - accuracy: 0.9917 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.7486 - mae: 0.0139 - val_loss: 0.0594 - val_tp: 12.0000 - val_fp: 14.0000 - val_tn: 593502.0000 - val_fn: 6472.0000 - val_accuracy: 0.9892 - val_precision: 0.4615 - val_recall: 0.0019 - val_auc: 0.7911 - val_mae: 0.0335 Epoch 3/20 3/3 [==============================] - 7s 2s/step - loss: 0.0514 - tp: 3.0000 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2541.0000 - accuracy: 0.9917 - precision: 1.0000 - recall: 0.0012 - auc: 0.8261 - mae: 0.0319 - val_loss: 0.0685 - val_tp: 16.0000 - val_fp: 10.0000 - val_tn: 593506.0000 - val_fn: 6468.0000 - val_accuracy: 0.9892 - val_precision: 0.6154 - val_recall: 0.0025 - val_auc: 0.7511 - val_mae: 0.0164 Epoch 4/20 3/3 [==============================] - 7s 2s/step - loss: 0.0622 - tp: 9.0000 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2535.0000 - accuracy: 0.9917 - precision: 1.0000 - recall: 0.0035 - auc: 0.7204 - mae: 0.0132 - val_loss: 0.0602 - val_tp: 24.0000 - val_fp: 19.0000 - val_tn: 593497.0000 - val_fn: 6460.0000 - val_accuracy: 0.9892 - val_precision: 0.5581 - val_recall: 0.0037 - val_auc: 0.8607 - val_mae: 0.0130 Epoch 5/20 3/3 [==============================] - 7s 2s/step - loss: 0.0545 - tp: 18.0000 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2526.0000 - accuracy: 0.9918 - precision: 1.0000 - recall: 0.0071 - auc: 0.8131 - mae: 0.0105 - 
val_loss: 0.0447 - val_tp: 42.0000 - val_fp: 50.0000 - val_tn: 593466.0000 - val_fn: 6442.0000 - val_accuracy: 0.9892 - val_precision: 0.4565 - val_recall: 0.0065 - val_auc: 0.9338 - val_mae: 0.0199 Epoch 6/20 3/3 [==============================] - 7s 2s/step - loss: 0.0361 - tp: 24.0000 - fp: 3.0000 - tn: 304653.0000 - fn: 2520.0000 - accuracy: 0.9918 - precision: 0.8889 - recall: 0.0094 - auc: 0.9250 - mae: 0.0166 - val_loss: 0.0441 - val_tp: 47.0000 - val_fp: 54.0000 - val_tn: 593462.0000 - val_fn: 6437.0000 - val_accuracy: 0.9892 - val_precision: 0.4653 - val_recall: 0.0072 - val_auc: 0.9319 - val_mae: 0.0183 Epoch 7/20 3/3 [==============================] - 7s 2s/step - loss: 0.0372 - tp: 27.0000 - fp: 9.0000 - tn: 304647.0000 - fn: 2517.0000 - accuracy: 0.9918 - precision: 0.7500 - recall: 0.0106 - auc: 0.9185 - mae: 0.0184 - val_loss: 0.0463 - val_tp: 50.0000 - val_fp: 50.0000 - val_tn: 593466.0000 - val_fn: 6434.0000 - val_accuracy: 0.9892 - val_precision: 0.5000 - val_recall: 0.0077 - val_auc: 0.9219 - val_mae: 0.0146 Epoch 8/20 3/3 [==============================] - 7s 2s/step - loss: 0.0343 - tp: 27.0000 - fp: 9.0000 - tn: 304647.0000 - fn: 2517.0000 - accuracy: 0.9918 - precision: 0.7500 - recall: 0.0106 - auc: 0.9329 - mae: 0.0129 - val_loss: 0.0466 - val_tp: 55.0000 - val_fp: 48.0000 - val_tn: 593468.0000 - val_fn: 6429.0000 - val_accuracy: 0.9892 - val_precision: 0.5340 - val_recall: 0.0085 - val_auc: 0.9266 - val_mae: 0.0123 Epoch 9/20 3/3 [==============================] - 7s 2s/step - loss: 0.0348 - tp: 36.0000 - fp: 9.0000 - tn: 304647.0000 - fn: 2508.0000 - accuracy: 0.9918 - precision: 0.8000 - recall: 0.0142 - auc: 0.9309 - mae: 0.0097 - val_loss: 0.0421 - val_tp: 71.0000 - val_fp: 56.0000 - val_tn: 593460.0000 - val_fn: 6413.0000 - val_accuracy: 0.9892 - val_precision: 0.5591 - val_recall: 0.0110 - val_auc: 0.9434 - val_mae: 0.0136 Epoch 10/20 3/3 [==============================] - 7s 2s/step - loss: 0.0327 - tp: 60.0000 - fp: 12.0000 - tn: 
304644.0000 - fn: 2484.0000 - accuracy: 0.9919 - precision: 0.8333 - recall: 0.0236 - auc: 0.9435 - mae: 0.0108 - val_loss: 0.0401 - val_tp: 97.0000 - val_fp: 82.0000 - val_tn: 593434.0000 - val_fn: 6387.0000 - val_accuracy: 0.9892 - val_precision: 0.5419 - val_recall: 0.0150 - val_auc: 0.9567 - val_mae: 0.0174 Epoch 11/20 3/3 [==============================] - 7s 2s/step - loss: 0.0312 - tp: 72.0000 - fp: 30.0000 - tn: 304626.0000 - fn: 2472.0000 - accuracy: 0.9919 - precision: 0.7059 - recall: 0.0283 - auc: 0.9574 - mae: 0.0146 - val_loss: 0.0396 - val_tp: 108.0000 - val_fp: 82.0000 - val_tn: 593434.0000 - val_fn: 6376.0000 - val_accuracy: 0.9892 - val_precision: 0.5684 - val_recall: 0.0167 - val_auc: 0.9533 - val_mae: 0.0127 Epoch 12/20 3/3 [==============================] - 7s 2s/step - loss: 0.0295 - tp: 75.0000 - fp: 36.0000 - tn: 304620.0000 - fn: 2469.0000 - accuracy: 0.9918 - precision: 0.6757 - recall: 0.0295 - auc: 0.9546 - mae: 0.0101 - val_loss: 0.0365 - val_tp: 149.0000 - val_fp: 96.0000 - val_tn: 593420.0000 - val_fn: 6335.0000 - val_accuracy: 0.9893 - val_precision: 0.6082 - val_recall: 0.0230 - val_auc: 0.9649 - val_mae: 0.0133 Epoch 13/20 3/3 [==============================] - 7s 2s/step - loss: 0.0273 - tp: 114.0000 - fp: 54.0000 - tn: 304602.0000 - fn: 2430.0000 - accuracy: 0.9919 - precision: 0.6786 - recall: 0.0448 - auc: 0.9680 - mae: 0.0107 - val_loss: 0.0393 - val_tp: 206.0000 - val_fp: 148.0000 - val_tn: 593368.0000 - val_fn: 6278.0000 - val_accuracy: 0.9893 - val_precision: 0.5819 - val_recall: 0.0318 - val_auc: 0.9756 - val_mae: 0.0218 Epoch 14/20 3/3 [==============================] - 7s 2s/step - loss: 0.0339 - tp: 159.0000 - fp: 93.0000 - tn: 304563.0000 - fn: 2385.0000 - accuracy: 0.9919 - precision: 0.6310 - recall: 0.0625 - auc: 0.9741 - mae: 0.0213 - val_loss: 0.0376 - val_tp: 177.0000 - val_fp: 102.0000 - val_tn: 593414.0000 - val_fn: 6307.0000 - val_accuracy: 0.9893 - val_precision: 0.6344 - val_recall: 0.0273 - val_auc: 0.9585 
- val_mae: 0.0120 Epoch 15/20 3/3 [==============================] - 7s 2s/step - loss: 0.0296 - tp: 132.0000 - fp: 69.0000 - tn: 304587.0000 - fn: 2412.0000 - accuracy: 0.9919 - precision: 0.6567 - recall: 0.0519 - auc: 0.9548 - mae: 0.0096 - val_loss: 0.0433 - val_tp: 165.0000 - val_fp: 85.0000 - val_tn: 593431.0000 - val_fn: 6319.0000 - val_accuracy: 0.9893 - val_precision: 0.6600 - val_recall: 0.0254 - val_auc: 0.9356 - val_mae: 0.0114 Epoch 16/20 3/3 [==============================] - 7s 2s/step - loss: 0.0346 - tp: 129.0000 - fp: 54.0000 - tn: 304602.0000 - fn: 2415.0000 - accuracy: 0.9920 - precision: 0.7049 - recall: 0.0507 - auc: 0.9272 - mae: 0.0091 - val_loss: 0.0432 - val_tp: 169.0000 - val_fp: 84.0000 - val_tn: 593432.0000 - val_fn: 6315.0000 - val_accuracy: 0.9893 - val_precision: 0.6680 - val_recall: 0.0261 - val_auc: 0.9349 - val_mae: 0.0113 Epoch 17/20 3/3 [==============================] - 7s 2s/step - loss: 0.0342 - tp: 132.0000 - fp: 54.0000 - tn: 304602.0000 - fn: 2412.0000 - accuracy: 0.9920 - precision: 0.7097 - recall: 0.0519 - auc: 0.9257 - mae: 0.0090 - val_loss: 0.0363 - val_tp: 216.0000 - val_fp: 110.0000 - val_tn: 593406.0000 - val_fn: 6268.0000 - val_accuracy: 0.9894 - val_precision: 0.6626 - val_recall: 0.0333 - val_auc: 0.9623 - val_mae: 0.0125 Epoch 18/20 3/3 [==============================] - 7s 2s/step - loss: 0.0297 - tp: 150.0000 - fp: 81.0000 - tn: 304575.0000 - fn: 2394.0000 - accuracy: 0.9919 - precision: 0.6494 - recall: 0.0590 - auc: 0.9514 - mae: 0.0101 - val_loss: 0.0375 - val_tp: 263.0000 - val_fp: 142.0000 - val_tn: 593374.0000 - val_fn: 6221.0000 - val_accuracy: 0.9894 - val_precision: 0.6494 - val_recall: 0.0406 - val_auc: 0.9667 - val_mae: 0.0188 Epoch 19/20 3/3 [==============================] - 7s 2s/step - loss: 0.0299 - tp: 186.0000 - fp: 99.0000 - tn: 304557.0000 - fn: 2358.0000 - accuracy: 0.9920 - precision: 0.6526 - recall: 0.0731 - auc: 0.9610 - mae: 0.0153 - val_loss: 0.0410 - val_tp: 314.0000 - val_fp: 
173.0000 - val_tn: 593343.0000 - val_fn: 6170.0000 - val_accuracy: 0.9894 - val_precision: 0.6448 - val_recall: 0.0484 - val_auc: 0.9669 - val_mae: 0.0243 Epoch 20/20 3/3 [==============================] - 7s 2s/step - loss: 0.0336 - tp: 204.0000 - fp: 114.0000 - tn: 304542.0000 - fn: 2340.0000 - accuracy: 0.9920 - precision: 0.6415 - recall: 0.0802 - auc: 0.9638 - mae: 0.0206 - val_loss: 0.0391 - val_tp: 326.0000 - val_fp: 183.0000 - val_tn: 593333.0000 - val_fn: 6158.0000 - val_accuracy: 0.9894 - val_precision: 0.6405 - val_recall: 0.0503 - val_auc: 0.9711 - val_mae: 0.0224
# Plot training and validation curves for every tracked metric.
titleDict = {'tp': "True Positives", 'fp': "False Positives",
             'tn': "True Negatives", 'fn': "False Negatives",
             'accuracy': "BinaryAccuracy", 'precision': "Precision",
             'recall': "Recall", 'auc': "Area under Curve",
             'mae': "Mean absolute error"}
fig, ax = plt.subplots(2, 5, figsize=(20, 8), dpi=300)
ax = ax.ravel()
for idx, (key, panel_title) in enumerate(titleDict.items()):
    ax[idx].plot(loss_history.epoch, loss_history.history[key],
                 color='coral', label='Training')
    ax[idx].plot(loss_history.epoch, loss_history.history['val_' + key],
                 color='cornflowerblue', label='Validation')
    ax[idx].set_title(panel_title)
ax[9].axis('off')
# One shared legend, taken from the first panel.
axLine, axLabel = ax[0].get_legend_handles_labels()
fig.legend(axLine, axLabel, bbox_to_anchor=(0.7, 0.3), loc='upper left')
# Predict on a window of the training data and compare with its mask.
train_win = train_img[:, wpos[1]:(wpos[1] + ww)]
unet_train_pred = t_unet.predict(prep_img(train_win))[0, :, :, 0]
fig, m_axs = plt.subplots(1, 3, figsize=(18, 4), dpi=150)
m_axs = m_axs.ravel()
for c_ax in m_axs:
    c_ax.axis('off')
m_axs[0].imshow(train_win, cmap='bone', vmin=0, vmax=4000)
m_axs[0].set_title('Train Image')
m_axs[1].imshow(unet_train_pred, cmap='viridis', vmin=0, vmax=0.2)
m_axs[1].set_title('Predicted Training')
m_axs[2].imshow(train_mask[:, wpos[1]:(wpos[1] + ww)], cmap='viridis')
m_axs[2].set_title('Train Mask')
# Predict on the held-out test window with the trained model.
unet_pred = t_unet.predict(prep_img(forigc))[0, :, :, 0]
fig, m_axs = plt.subplots(1, 3, figsize=(18, 4), dpi=150)
m_axs = m_axs.ravel()
for c_ax in m_axs:
    c_ax.axis('off')
m_axs[0].imshow(forigc, cmap='bone', vmin=0, vmax=4000)
m_axs[0].set_title('Full Image')
f1 = m_axs[1].imshow(unet_pred, cmap='viridis', vmin=0, vmax=0.1)
m_axs[1].set_title('Predicted Segmentation')
fig.colorbar(f1, ax=m_axs[1])
m_axs[2].imshow(maskc, cmap='viridis')
m_axs[2].set_title('Ground Truth')
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax0 = ax[0].imshow(unet_pred, vmin=0, vmax=0.1)
ax[0].set_title('Predicted segmentation')
fig.colorbar(ax0, ax=ax[0])
# Threshold the prediction at 0.05 to obtain the final binary mask.
# Fixed user-facing typo in the panel title ('segmenation' -> 'segmentation').
ax[1].imshow(0.05 < unet_pred)
ax[1].set_title('Final segmentation')
gt = maskc
pr = 0.05 < unet_pred
ps.showHitCases(gt, pr, cmap='gray')
# Per-pixel hit map comparing ground truth (gt) and prediction (pr).
fig, ax = plt.subplots(1,2,figsize=(12,4))
ps.showHitMap(gt,pr,ax=ax)